import numpy as np
import os
import time
import cv2
from skimage import io

def load_local_train_val(config=None, train_path=None):
    """Load every training image together with its class label and driver id.

    Images are read from ``<train_path>/<class_dir>/<image>``. The label is
    the last character of the class directory name parsed as an int, and the
    driver of each image is looked up in the mapping returned by
    ``get_drivers()``.

    Args:
        config: dict-like with the string flags ``'crop_center'`` and
            ``'to_gray'`` (enabled when equal to the string ``'True'``) and
            ``'resize'``, a 2-sequence passed to ``cv2.resize`` as ``dsize``.
            Required despite the ``None`` default (kept for compatibility).
        train_path: root directory containing the class sub-directories;
            defaults to the original hard-coded dataset location.

    Returns:
        (image_narrays, lable_narrays, drivers_id, unique_drivers):
        the images transposed to (N, C, H, W); an array of N int labels;
        a list of N driver ids aligned with the images; the sorted list of
        distinct driver ids.

    Raises:
        ValueError: if ``config`` is None.
        IOError: if OpenCV cannot decode one of the files.
        KeyError: if an image name is missing from the drivers list.
    """
    if config is None:
        raise ValueError('load_local_train_val requires a config dict')
    if train_path is None:
        train_path = '/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDistractedDriverDetection/train'
    drivers_all = get_drivers()
    to_gray = config['to_gray'] == 'True'
    labels = []
    img_arrs = []
    drivers_id = []  # driver of each loaded image, aligned with img_arrs
    for c in os.listdir(train_path):
        begin = time.time()
        count = 0
        class_dir = os.path.join(train_path, c)
        # Assumes each class directory name ends in its digit label (e.g. 'c7').
        label = int(c[-1])
        for img in os.listdir(class_dir):
            drivers_id.append(drivers_all[img])
            img_path = os.path.join(class_dir, img)
            img_arr = cv2.imread(img_path)
            if img_arr is None:
                # cv2.imread reports unreadable / non-image files by returning None.
                raise IOError('Could not read image: {0}'.format(img_path))
            if config['crop_center'] == 'True':
                # Keep columns 80..559 only (assumes 640-px-wide frames - TODO confirm).
                img_arr = img_arr[:, 80:560]
            img_arr = cv2.resize(img_arr, tuple(config['resize']), interpolation=cv2.INTER_AREA)
            if to_gray:
                # NOTE(review): cv2.imread returns BGR, so COLOR_BGR2GRAY is the
                # matching code. COLOR_RGBA2GRAY is kept so this stays identical to
                # the other loaders in this module; change them all together.
                img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGBA2GRAY)
            labels.append(label)
            img_arrs.append(img_arr)
            count += 1
        print(count)
        print('time: {0}'.format(time.time() - begin))
    unique_drivers = sorted(set(drivers_id))
    lable_narrays = np.asarray(labels)
    image_narrays = np.asarray(img_arrs)
    if to_gray:
        # Grayscale images have no channel axis; add one so the transpose works.
        image_narrays = np.expand_dims(image_narrays, 3)
    # (N, H, W, C) -> (N, C, H, W)
    image_narrays = np.transpose(image_narrays, (0, 3, 1, 2))
    return image_narrays, lable_narrays, drivers_id, unique_drivers

def load_test(config, test_path=None):
    """Load every test image together with its file name.

    Args:
        config: dict-like with the string flags ``'crop_center'``,
            ``'to_gray'`` and ``'debug'`` (enabled when equal to the string
            ``'True'``) and ``'resize'``, a 2-sequence passed to
            ``cv2.resize`` as ``dsize``.
        test_path: directory holding the test images; defaults to the
            original hard-coded dataset location.

    Returns:
        (image_narrays, img_names): the images transposed to (N, C, H, W)
        and the list of N file names aligned with them. In debug mode,
        loading stops after the first 2500 images.

    Raises:
        IOError: if OpenCV cannot decode one of the files.
    """
    if test_path is None:
        test_path = '/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDistractedDriverDetection/test'
    to_gray = config['to_gray'] == 'True'
    img_arrs = []
    img_names = []
    begin = time.time()
    for img in os.listdir(test_path):
        img_path = os.path.join(test_path, img)
        img_arr = cv2.imread(img_path)
        if img_arr is None:
            # cv2.imread reports unreadable / non-image files by returning None.
            raise IOError('Could not read image: {0}'.format(img_path))
        if config['crop_center'] == 'True':
            # Keep columns 80..559 only (assumes 640-px-wide frames - TODO confirm).
            img_arr = img_arr[:, 80:560]
        img_arr = cv2.resize(img_arr, tuple(config['resize']), interpolation=cv2.INTER_AREA)
        if to_gray:
            # NOTE(review): cv2.imread returns BGR, so COLOR_BGR2GRAY is the
            # matching code. COLOR_RGBA2GRAY is kept so this stays identical to
            # the training loaders in this module; change them all together.
            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGBA2GRAY)
        img_names.append(img)
        img_arrs.append(img_arr)
        # Count only after storing the image, so the reported number (and the
        # debug cut-off) match what was actually loaded.
        loaded = len(img_arrs)
        if loaded % 2500 == 0:
            print('{0} images loaded, time: {1}'.format(loaded, time.time() - begin))
            if config['debug'] == 'True':
                break
    image_narrays = np.asarray(img_arrs)
    if to_gray:
        # Grayscale images have no channel axis; add one so the transpose works.
        image_narrays = np.expand_dims(image_narrays, 3)
    # (N, H, W, C) -> (N, C, H, W)
    image_narrays = np.transpose(image_narrays, (0, 3, 1, 2))
    return image_narrays, img_names

def get_drivers(path=None):
    """Map each training image file name to the id of its driver.

    Each line of the CSV is stripped and split on commas. The first field
    (driver id) is stored under the third field (image file name).

    Args:
        path: CSV file to read; defaults to the original hard-coded
            ``driver_imgs_list.csv`` location.

    Returns:
        dict mapping image file name -> driver id. Every line with at least
        three fields is stored, including a header line if the file has one.
        Blank or shorter lines are skipped.
    """
    if path is None:
        path = '/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDistractedDriverDetection/driver_imgs_list.csv'
    print('Read drivers data')
    drivers = {}
    with open(path, 'r') as f:
        for line in f:
            fields = line.strip().split(',')
            if len(fields) < 3:
                # A blank/short line used to raise IndexError here.
                continue
            drivers[fields[2]] = fields[0]
    return drivers

def select_drivers(image_narrays, lable_narrays, drivers_id, drivers_list):
    """Keep only the samples whose driver id appears in ``drivers_list``.

    Args:
        image_narrays: indexable collection of N images.
        lable_narrays: indexable collection of N labels aligned with the images.
        drivers_id: sequence of N driver ids aligned with the images.
        drivers_list: iterable of (hashable) driver ids to keep.

    Returns:
        (data, target, index): the selected images and labels as float32
        arrays, and the int64 positions of the selected samples in the input,
        in their original order.
    """
    wanted = set(drivers_list)  # O(1) membership instead of a list scan per sample
    index = [i for i, driver in enumerate(drivers_id) if driver in wanted]
    data = np.asarray([image_narrays[i] for i in index], dtype=np.float32)
    target = np.asarray([lable_narrays[i] for i in index], dtype=np.float32)
    # int64 rather than uint8: uint8 cannot represent positions above 255
    # (they silently wrapped on older NumPy and raise OverflowError on newer).
    index = np.asarray(index, dtype=np.int64)
    return data, target, index

def load_local_train_val_DEBUG(config=None, train_path=None, max_per_class=200):
    """Debug variant of the training loader that reads a capped number of
    images per class.

    Args:
        config: dict-like with the string flags ``'crop_center'`` and
            ``'to_gray'`` (enabled when equal to the string ``'True'``) and
            ``'resize'``, a 2-sequence passed to ``cv2.resize`` as ``dsize``.
            Required despite the ``None`` default (kept for compatibility).
        train_path: root directory containing the class sub-directories;
            defaults to the original hard-coded dataset location.
        max_per_class: maximum number of images loaded from each class
            directory. The previous code loaded one more than intended (201).

    Returns:
        (image_narrays, lable_narrays, drivers_id, unique_drivers):
        the images transposed to (N, C, H, W); an array of N int labels;
        a list of N driver ids aligned with the images; the sorted list of
        distinct driver ids.

    Raises:
        ValueError: if ``config`` is None.
        IOError: if OpenCV cannot decode one of the files.
        KeyError: if an image name is missing from the drivers list.
    """
    if config is None:
        raise ValueError('load_local_train_val_DEBUG requires a config dict')
    if train_path is None:
        train_path = '/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDistractedDriverDetection/train'
    drivers_all = get_drivers()
    to_gray = config['to_gray'] == 'True'
    labels = []
    img_arrs = []
    drivers_id = []  # driver of each loaded image, aligned with img_arrs
    for c in os.listdir(train_path):
        begin = time.time()
        count = 0
        class_dir = os.path.join(train_path, c)
        # Assumes each class directory name ends in its digit label (e.g. 'c7').
        label = int(c[-1])
        for img in os.listdir(class_dir):
            # Check the cap before loading so exactly max_per_class images are kept.
            if count >= max_per_class:
                break
            drivers_id.append(drivers_all[img])
            img_path = os.path.join(class_dir, img)
            img_arr = cv2.imread(img_path)
            if img_arr is None:
                # cv2.imread reports unreadable / non-image files by returning None.
                raise IOError('Could not read image: {0}'.format(img_path))
            if config['crop_center'] == 'True':
                # Keep columns 80..559 only (assumes 640-px-wide frames - TODO confirm).
                img_arr = img_arr[:, 80:560]
            img_arr = cv2.resize(img_arr, tuple(config['resize']), interpolation=cv2.INTER_AREA)
            if to_gray:
                # NOTE(review): cv2.imread returns BGR, so COLOR_BGR2GRAY is the
                # matching code. COLOR_RGBA2GRAY is kept so this stays identical to
                # the other loaders in this module; change them all together.
                img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGBA2GRAY)
            labels.append(label)
            img_arrs.append(img_arr)
            count += 1
        print(count)
        print('time: {0}'.format(time.time() - begin))
    unique_drivers = sorted(set(drivers_id))
    lable_narrays = np.asarray(labels)
    image_narrays = np.asarray(img_arrs)
    if to_gray:
        # Grayscale images have no channel axis; add one so the transpose works.
        image_narrays = np.expand_dims(image_narrays, 3)
    # (N, H, W, C) -> (N, C, H, W)
    image_narrays = np.transpose(image_narrays, (0, 3, 1, 2))
    return image_narrays, lable_narrays, drivers_id, unique_drivers

